package com.sfzd5.StudyJavaCV.mnist;

import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_ml;
import org.opencv.core.Core;

import static com.sfzd5.StudyJavaCV.mnist.MnistRead.*;
import static org.bytedeco.javacpp.opencv_core.CV_32FC1;
import static org.bytedeco.javacpp.opencv_core.minMaxLoc;
import static org.bytedeco.javacpp.opencv_ml.ROW_SAMPLE;
import static org.bytedeco.javacpp.opencv_ml.StatModel.UPDATE_MODEL;

public class MnistMlp {

    /**
     * Trains a multi-layer perceptron on the MNIST training set and saves the model.
     *
     * <p>The network has 28*28 input neurons, two hidden layers (512 and 256 neurons) and 10
     * output neurons (one per digit), and uses the symmetric sigmoid activation.
     *
     * @param xml path of the model file to write
     */
    public static void train(String xml) {
        opencv_core.Mat trainData = MnistRead.getTrainData(TRAIN_IMAGES_FILE);
        opencv_core.Mat labels = MnistRead.getTrainLabels(TRAIN_LABELS_FILE);

        opencv_ml.ANN_MLP mlp = opencv_ml.ANN_MLP.create();

        int imageCols = 28; // image width in pixels
        int imageRows = 28; // image height in pixels
        int classNum = 10;  // number of output neurons; a prediction is a float[10] row

        // Network topology: input layer, two hidden layers, output layer.
        int[] layers = {imageCols * imageRows, 512, 256, classNum};
        opencv_core.Mat layerSizes = new opencv_core.Mat(1, layers.length, CV_32FC1);
        FloatIndexer layerIndexer = layerSizes.createIndexer();
        for (int i = 0; i < layers.length; i++) {
            // Explicit (row, col) addressing: layerSizes is a single-row matrix.
            layerIndexer.put(0, i, layers[i]);
        }

        mlp.setLayerSizes(layerSizes);
        mlp.setActivationFunction(opencv_ml.ANN_MLP.SIGMOID_SYM);

        // Start training.
        mlp.train(trainData, ROW_SAMPLE, labels);

        mlp.save(xml);
        mlp.clear();
        System.out.println("训练结束");
    }

    /**
     * Recognizes a single digit image.
     *
     * <p>Note: the model file is loaded on every call.
     *
     * @param xml    path of a model file previously written by {@link #train(String)}
     * @param sample the sample to classify, e.g. {@code new Mat(1, 28*28, CV_32FC1)}
     * @return the predicted digit (index of the largest network output), or -1 if the
     *         network produced no outputs
     */
    public static int predict(String xml, opencv_core.Mat sample) {
        opencv_ml.ANN_MLP ann = opencv_ml.ANN_MLP.load(xml);
        opencv_core.Mat output = new opencv_core.Mat();
        // Flags 0: plain prediction. UPDATE_MODEL is a StatModel training flag, not a
        // prediction flag.
        ann.predict(sample, output, 0);
        return getMaxIndex(output);
    }

    /**
     * Measures recognition accuracy on the MNIST test set and prints it.
     *
     * @param xml path of a model file previously written by {@link #train(String)}
     * @throws IllegalStateException if there are fewer labels than test samples
     */
    public static void test(String xml) {
        opencv_ml.ANN_MLP ann = opencv_ml.ANN_MLP.load(xml);

        opencv_core.Mat testData = MnistRead.getTrainData(TEST_IMAGES_FILE);
        byte[] testLabels = MnistRead.getLabels(TEST_LABELS_FILE);

        int rows = testData.rows();
        if (testLabels.length < rows) {
            throw new IllegalStateException(
                    "label count " + testLabels.length + " is smaller than sample count " + rows);
        }

        // Number of correctly classified samples.
        int correct = 0;
        // One output buffer reused for every sample instead of allocating a native Mat per row;
        // OpenCV output arrays are only reallocated when their size or type changes.
        opencv_core.Mat output = new opencv_core.Mat();
        for (int i = 0; i < rows; i++) {
            ann.predict(testData.row(i), output, 0);
            if (testLabels[i] == getMaxIndex(output)) {
                correct++;
            }
        }

        // Accuracy = correct / total.
        double accuracy = (double) correct / rows;
        System.out.println("正确率：" + accuracy);
    }

    /**
     * Returns the column index of the largest value in the first row of {@code predict}.
     *
     * <p>The network output is one row with one value per class (matching the training
     * layout), so the index of the maximum is the predicted digit. The search is seeded with
     * the first element rather than a fixed sentinel: with SIGMOID_SYM and default parameters
     * OpenCV documents an output range of [-1.7159, 1.7159], so all outputs may be below -1.
     *
     * @param predict the Mat returned by {@code ANN_MLP.predict}
     * @return index of the maximum value in row 0 (first one on ties), or -1 if there are no
     *         columns
     */
    private static int getMaxIndex(opencv_core.Mat predict) {
        int cols = predict.cols();
        if (cols <= 0) {
            return -1;
        }
        FloatIndexer indexer = predict.createIndexer();
        int best = 0;
        float bestScore = indexer.get(0, 0);
        for (int c = 1; c < cols; c++) {
            // Explicit (row, col) addressing also works for row views with a parent stride.
            float score = indexer.get(0, c);
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        return best;
    }
}
